/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.web.socket.server.standard;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.EndpointConfig;
import javax.websocket.Extension;
import javax.websocket.WebSocketContainer;

import org.glassfish.tyrus.core.ComponentProviderService;
import org.glassfish.tyrus.core.RequestContext;
import org.glassfish.tyrus.core.TyrusEndpointWrapper;
import org.glassfish.tyrus.core.TyrusUpgradeResponse;
import org.glassfish.tyrus.core.TyrusWebSocketEngine;
import org.glassfish.tyrus.core.Version;
import org.glassfish.tyrus.server.TyrusServerContainer;
import org.glassfish.tyrus.spi.WebSocketEngine.UpgradeInfo;

import org.springframework.beans.DirectFieldAccessor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.server.HandshakeFailureException;

import static org.glassfish.tyrus.spi.WebSocketEngine.UpgradeStatus.SUCCESS;

/**
 * A base class for {@code RequestUpgradeStrategy} implementations on top of
 * JSR-356 based servers which include Tyrus as their WebSocket engine.
 *
 * <p>Works with Tyrus 1.11 (WebLogic 12.2.1) and Tyrus 1.12 (GlassFish 4.1.1).
 *
 * @author Rossen Stoyanchev
 * @author Brian Clozel
 * @author Juergen Hoeller
 * @see <a href="https://tyrus.java.net/">Project Tyrus</a>
 * @since 4.1
 */
public abstract class AbstractTyrusRequestUpgradeStrategy extends AbstractStandardUpgradeStrategy {

    private static final Random random = new Random();

    private static final Constructor<?> constructor;

    private static boolean constructorWithBooleanArgument;

    private static final Method registerMethod;

    private static final Method unRegisterMethod;

    static {
        try {
            constructor = getEndpointConstructor();
            int parameterCount = constructor.getParameterCount();
            constructorWithBooleanArgument = (parameterCount == 10);
            if (!constructorWithBooleanArgument && parameterCount != 9) {
                throw new IllegalStateException("Expected TyrusEndpointWrapper constructor with 9 or 10 arguments");
            }
            registerMethod = TyrusWebSocketEngine.class.getDeclaredMethod("register", TyrusEndpointWrapper.class);
            unRegisterMethod = TyrusWebSocketEngine.class.getDeclaredMethod("unregister", TyrusEndpointWrapper.class);
            ReflectionUtils.makeAccessible(registerMethod);
        } catch (Exception ex) {
            throw new IllegalStateException("No compatible Tyrus version found", ex);
        }
    }

    private static Constructor<?> getEndpointConstructor() {
        for (Constructor<?> current : TyrusEndpointWrapper.class.getConstructors()) {
            Class<?>[] types = current.getParameterTypes();
            if (Endpoint.class == types[0] && EndpointConfig.class == types[1]) {
                return current;
            }
        }
        throw new IllegalStateException("No compatible Tyrus version found");
    }


    private final ComponentProviderService componentProvider = ComponentProviderService.create();


    @Override
    public String[] getSupportedVersions() {
        return StringUtils.tokenizeToStringArray(Version.getSupportedWireProtocolVersions(), ",");
    }

    @Override
    protected List<WebSocketExtension> getInstalledExtensions(WebSocketContainer container) {
        try {
            return super.getInstalledExtensions(container);
        } catch (UnsupportedOperationException ex) {
            return new ArrayList<>(0);
        }
    }

    @Override
    public void upgradeInternal(ServerHttpRequest request, ServerHttpResponse response,
                                @Nullable String selectedProtocol, List<Extension> extensions, Endpoint endpoint)
            throws HandshakeFailureException {

        HttpServletRequest servletRequest = getHttpServletRequest(request);
        HttpServletResponse servletResponse = getHttpServletResponse(response);

        TyrusServerContainer serverContainer = (TyrusServerContainer) getContainer(servletRequest);
        TyrusWebSocketEngine engine = (TyrusWebSocketEngine) serverContainer.getWebSocketEngine();
        Object tyrusEndpoint = null;
        boolean success;

        try {
            // Shouldn't matter for processing but must be unique
            String path = "/" + random.nextLong();
            tyrusEndpoint = createTyrusEndpoint(endpoint, path, selectedProtocol, extensions, serverContainer, engine);
            register(engine, tyrusEndpoint);

            HttpHeaders headers = request.getHeaders();
            RequestContext requestContext = createRequestContext(servletRequest, path, headers);
            TyrusUpgradeResponse upgradeResponse = new TyrusUpgradeResponse();
            UpgradeInfo upgradeInfo = engine.upgrade(requestContext, upgradeResponse);
            success = SUCCESS.equals(upgradeInfo.getStatus());
            if (success) {
                if (logger.isTraceEnabled()) {
                    logger.trace("Successful request upgrade: " + upgradeResponse.getHeaders());
                }
                handleSuccess(servletRequest, servletResponse, upgradeInfo, upgradeResponse);
            }
        } catch (Exception ex) {
            unregisterTyrusEndpoint(engine, tyrusEndpoint);
            throw new HandshakeFailureException("Error during handshake: " + request.getURI(), ex);
        }

        unregisterTyrusEndpoint(engine, tyrusEndpoint);
        if (!success) {
            throw new HandshakeFailureException("Unexpected handshake failure: " + request.getURI());
        }
    }

    private Object createTyrusEndpoint(Endpoint endpoint, String endpointPath, @Nullable String protocol,
                                       List<Extension> extensions, WebSocketContainer container, TyrusWebSocketEngine engine)
            throws DeploymentException {

        ServerEndpointRegistration endpointConfig = new ServerEndpointRegistration(endpointPath, endpoint);
        endpointConfig.setSubprotocols(Collections.singletonList(protocol));
        endpointConfig.setExtensions(extensions);
        return createEndpoint(endpointConfig, this.componentProvider, container, engine);
    }

    private RequestContext createRequestContext(HttpServletRequest request, String endpointPath, HttpHeaders headers) {
        RequestContext context =
                RequestContext.Builder.create()
                        .requestURI(URI.create(endpointPath))
                        .userPrincipal(request.getUserPrincipal())
                        .secure(request.isSecure())
                        .remoteAddr(request.getRemoteAddr())
                        .build();
        headers.forEach((header, value) -> context.getHeaders().put(header, value));
        return context;
    }

    private void unregisterTyrusEndpoint(TyrusWebSocketEngine engine, @Nullable Object tyrusEndpoint) {
        if (tyrusEndpoint != null) {
            try {
                unregister(engine, tyrusEndpoint);
            } catch (Throwable ex) {
                // ignore
            }
        }
    }

    private Object createEndpoint(ServerEndpointRegistration registration, ComponentProviderService provider,
                                  WebSocketContainer container, TyrusWebSocketEngine engine) throws DeploymentException {

        DirectFieldAccessor accessor = new DirectFieldAccessor(engine);
        Object sessionListener = accessor.getPropertyValue("sessionListener");
        Object clusterContext = accessor.getPropertyValue("clusterContext");
        try {
            if (constructorWithBooleanArgument) {
                // Tyrus 1.11+
                return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
                        "/", registration.getConfigurator(), sessionListener, clusterContext, null, Boolean.TRUE);
            } else {
                return constructor.newInstance(registration.getEndpoint(), registration, provider, container,
                        "/", registration.getConfigurator(), sessionListener, clusterContext, null);
            }
        } catch (Exception ex) {
            throw new HandshakeFailureException("Failed to register " + registration, ex);
        }
    }

    private void register(TyrusWebSocketEngine engine, Object endpoint) {
        try {
            registerMethod.invoke(engine, endpoint);
        } catch (Exception ex) {
            throw new HandshakeFailureException("Failed to register " + endpoint, ex);
        }
    }

    private void unregister(TyrusWebSocketEngine engine, Object endpoint) {
        try {
            unRegisterMethod.invoke(engine, endpoint);
        } catch (Exception ex) {
            throw new HandshakeFailureException("Failed to unregister " + endpoint, ex);
        }
    }


    protected abstract void handleSuccess(HttpServletRequest request, HttpServletResponse response,
                                          UpgradeInfo upgradeInfo, TyrusUpgradeResponse upgradeResponse) throws IOException, ServletException;

}
